#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   main.py    
@Contact :   raogx.vip@hotmail.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA

@Modify Time      @Author    @Version    @Description
------------      -------    --------    ------------
2021/7/17 9:45 PM    caijiahao      1.0         Bytedancer
'''

# import lib

import torch
from torch import nn
from torch import optim
from PIL import Image
import numpy as np

print(torch.cuda.is_available())
# Prefer the first GPU, but fall back to the CPU so the script can still run
# (slowly) on machines without CUDA instead of failing on the first .to(device).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Prefix (directory including trailing separator) holding cat.N.jpg / dog.N.jpg.
path = ''

IMAGE_SIZE = 224       # side length every image is resized to
NUM_PER_CLASS = 1000   # files loaded per class: cat.0..cat.999, dog.0..dog.999
NUM_IMAGES = 2 * NUM_PER_CLASS

# Pixel buffer in (N, H, W, C) order, filled by _load_class and scaled to [0, 1].
train_X = np.empty((NUM_IMAGES, IMAGE_SIZE, IMAGE_SIZE, 3), dtype="float32")
# Class labels: 0 for cat images, 1 for dog images.
train_Y = np.empty((NUM_IMAGES,), dtype="int")


def _load_class(prefix, label, offset):
    """Load ``path + prefix + ".<i>.jpg"`` for i in [0, NUM_PER_CLASS).

    Rows ``offset`` .. ``offset + NUM_PER_CLASS - 1`` of the module-level
    ``train_X`` / ``train_Y`` receive the resized pixels and ``label``.
    Each image is converted to RGB first so files stored in another mode
    (e.g. grayscale or CMYK JPEGs) still fit the 3-channel buffer, then
    resized with the Lanczos filter (the filter the ``Image.ANTIALIAS``
    alias pointed to before it was removed in Pillow 10).
    """
    for i in range(NUM_PER_CLASS):
        file_path = path + prefix + "." + str(i) + ".jpg"
        # The context manager closes the file handle right away instead of
        # leaving thousands of open files for the garbage collector.
        with Image.open(file_path) as image:
            resized_image = image.convert("RGB").resize(
                (IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
        # uint8 pixels are cast to float32 on assignment into the buffer.
        train_X[offset + i] = np.asarray(resized_image)
        train_Y[offset + i] = label


_load_class("cat", 0, 0)
_load_class("dog", 1, NUM_PER_CLASS)

train_X /= 255

# Shuffle images and labels with the same permutation so pairs stay aligned.
index = np.arange(NUM_IMAGES)
np.random.shuffle(index)
train_X = train_X[index]
train_Y = train_Y[index]

# (N, H, W, C) -> (N, C, H, W), the layout nn.Conv2d expects; made contiguous
# so later batch slices are cheap to hand to torch.
train_XX = np.ascontiguousarray(train_X.transpose(0, 3, 1, 2))

# Create the network: a VGG-16-style CNN with BatchNorm after each conv stage.
class Net(nn.Module):
    """VGG-16-style convolutional classifier producing two output logits.

    Five conv stages (2, 2, 3, 3 and 3 conv layers of 3x3, each followed by
    ReLU; every stage ends with BatchNorm and a 2x2 max-pool) feed a
    three-layer fully connected head (``dense1``).

    Input: float tensor of shape (N, 3, H, W). The script feeds 224x224
    images; any H, W >= 32 is accepted because the conv5 output is
    adaptively pooled to 7x7 before the head.
    Output: raw (unnormalized) logits of shape (N, 2); the script pairs them
    with ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64, eps=1e-05, momentum=0.1, affine=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128, eps=1e-5, momentum=0.1, affine=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256, eps=1e-5, momentum=0.1, affine=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512, eps=1e-5, momentum=0.1, affine=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512, eps=1e-5, momentum=0.1, affine=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.dense1 = nn.Sequential(
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2)
        )
        # Brings the conv5 feature map to 7x7 so dense1 always receives
        # 7*7*512 features. For 224x224 input conv5 already emits 7x7 and this
        # is an exact identity. It owns no parameters or buffers, so existing
        # state_dicts still load, and it is registered last so the order of
        # the other children/parameters is unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, x):
        """Map a (N, 3, H, W) batch to (N, 2) logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avgpool(x)
        # Flatten per sample while keeping the batch dimension. The previous
        # view(-1, 7*7*512) could silently regroup features across samples
        # whenever the per-sample feature count differed from 7*7*512.
        x = torch.flatten(x, 1)
        x = self.dense1(x)
        return x


batch_size = 16
net = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0005)

num_samples = len(train_XX)
train_loss = []
for epoch in range(10):
    # Training pass. net.train() is needed every epoch because the evaluation
    # below switches the model to eval mode (BatchNorm uses running stats).
    net.train()
    train_loss = []  # per-batch losses of this epoch
    # range(start, stop, step) yields integer batch offsets; the old
    # range(2000 / batch_size) passed a float and raised TypeError in Python 3.
    for start in range(0, num_samples, batch_size):
        x = torch.from_numpy(train_XX[start:start + batch_size]).to(device)
        y = torch.from_numpy(train_Y[start:start + batch_size]).long().to(device)

        out = net(x)

        loss = criterion(out, y)  # loss between predictions and labels
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()  # backpropagate to compute parameter gradients
        optimizer.step()  # apply the gradient update to net's parameters
        train_loss.append(loss.item())

        print(epoch, start, train_loss[-1])

    # Evaluation pass: eval mode so BatchNorm uses (and does not update) its
    # running statistics, and no_grad so no autograd graph is built.
    # NOTE(review): despite the 'test acc' label this scores the training set.
    net.eval()
    total_correct = 0
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            x = torch.from_numpy(train_XX[start:start + batch_size]).to(device)
            pred = net(x).argmax(dim=1).cpu().numpy()
            total_correct += int((pred == train_Y[start:start + batch_size]).sum())

    acc = total_correct / float(num_samples)
    print('train loss:', np.mean(train_loss))
    print('test acc:', acc)
    torch.cuda.empty_cache()
